{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.functional as F\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Toy parallel corpus, consumed by get_batch as:\n",
    "#   [0] encoder input, [1] decoder input, [2] decoder target.\n",
    "# NOTE(review): 'S' / 'E' look like start / end markers and 'P' like padding - confirm.\n",
    "sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']\n",
    "\n",
    "# Sort the vocabulary: iteration order of a set of strings changes between interpreter\n",
    "# runs, which would otherwise give every kernel restart a different word-to-id mapping.\n",
    "words = sorted(set(' '.join(sentences).split()))\n",
    "word2id = {w: i for i, w in enumerate(words)}\n",
    "id2word = {i: w for i, w in enumerate(words)}\n",
    "\n",
    "n_class = len(words)  # vocabulary size; also the width of the one-hot vectors\n",
    "n_step = 5  # kept for compatibility; the model derives step counts from its inputs\n",
    "n_hidden = 128  # hidden size of the encoder and decoder RNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_batch(sentences):\n",
    "    \"\"\"Turn the three sentences into a single-example batch of tensors.\n",
    "\n",
    "    sentences[0] and sentences[1] become one-hot matrices (rows taken from an\n",
    "    n_class x n_class identity matrix), each wrapped in a batch of one:\n",
    "    [1, n_words, n_class]. sentences[2] becomes a LongTensor of word ids,\n",
    "    shaped [1, n_words].\n",
    "\n",
    "    Returns the tuple (encoder input, decoder input, target ids).\n",
    "    \"\"\"\n",
    "    def to_ids(sentence):\n",
    "        # Vocabulary id of every whitespace-separated token.\n",
    "        return [word2id[token] for token in sentence.split()]\n",
    "\n",
    "    one_hot_rows = np.eye(n_class)\n",
    "    encoder_rows = [one_hot_rows[to_ids(sentences[0])]]\n",
    "    decoder_rows = [one_hot_rows[to_ids(sentences[1])]]\n",
    "    target_ids = [to_ids(sentences[2])]\n",
    "\n",
    "    encoder_tensor = Variable(torch.Tensor(encoder_rows))\n",
    "    decoder_tensor = Variable(torch.Tensor(decoder_rows))\n",
    "    target_tensor = Variable(torch.Tensor(target_ids).type(torch.LongTensor))\n",
    "    return encoder_tensor, decoder_tensor, target_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Seq2seq_att(nn.Module):\n",
    "    \"\"\"RNN encoder-decoder with dot-product attention over the encoder outputs.\n",
    "\n",
    "    Supports batch size 1 only: each attention score is a flat dot product\n",
    "    between one projected encoder step and the current decoder output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Seq2seq_att, self).__init__()\n",
    "        # Single-layer RNNs taking [n_step, batch_size, n_class] inputs. nn.RNN applies\n",
    "        # dropout only between stacked layers, so the former dropout=0.5 had no effect\n",
    "        # here beyond raising a UserWarning; it is omitted.\n",
    "        self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden)\n",
    "        # Projects an encoder output before it is compared with the decoder output.\n",
    "        self.attn = nn.Linear(n_hidden, n_hidden)\n",
    "        # Maps the concatenated [decoder output, context] vector to vocabulary logits.\n",
    "        self.out = nn.Linear(n_hidden * 2, n_class)\n",
    "\n",
    "    def get_att_weight(self, enc_output, dec_output):\n",
    "        \"\"\"Return attention weights over the encoder steps, shaped [1, 1, n_step].\n",
    "\n",
    "        enc_output: encoder outputs, [n_step, 1, n_hidden].\n",
    "        dec_output: current decoder output, [1, 1, n_hidden].\n",
    "        \"\"\"\n",
    "        att_scores = torch.stack(\n",
    "            [self.get_att_score(step_output, dec_output) for step_output in enc_output]\n",
    "        )\n",
    "        # att_scores is 1-D (one score per encoder step), so normalise along dim 0.\n",
    "        return nn.functional.softmax(att_scores, dim=0).view(1, 1, -1)\n",
    "\n",
    "    def get_att_score(self, enc_output, dec_output):\n",
    "        \"\"\"Return the scalar score dot(attn(enc_output), dec_output) for one encoder step.\"\"\"\n",
    "        score = self.attn(enc_output)\n",
    "        return torch.dot(score.view(-1), dec_output.view(-1))\n",
    "\n",
    "    def forward(self, enc_input, enc_hidden, dec_input):\n",
    "        \"\"\"Run the encoder, then decode step by step with attention.\n",
    "\n",
    "        enc_input: [1, n_step, n_class] one-hot encoder input.\n",
    "        enc_hidden: initial encoder state, [1, 1, n_hidden].\n",
    "        dec_input: [1, n_step, n_class] one-hot decoder input.\n",
    "\n",
    "        Returns (logits, train_att): logits is [n_step, n_class]; train_att is a\n",
    "        list holding one numpy array of attention weights per decoder step.\n",
    "        \"\"\"\n",
    "        # nn.RNN expects time-major input: [n_step, batch_size, n_class].\n",
    "        enc_input = enc_input.transpose(0, 1)\n",
    "        dec_input = dec_input.transpose(0, 1)\n",
    "\n",
    "        # The decoder starts from the encoder's final state. The previous code reset\n",
    "        # `hidden` to the initial enc_hidden here, discarding what the encoder had read.\n",
    "        enc_output, hidden = self.enc_cell(enc_input, enc_hidden)\n",
    "\n",
    "        train_att = []\n",
    "        step_logits = []\n",
    "        for step_input in dec_input:\n",
    "            # dec_output: [1, batch_size, n_hidden]\n",
    "            dec_output, hidden = self.dec_cell(step_input.unsqueeze(0), hidden)\n",
    "            attn_weight = self.get_att_weight(enc_output, dec_output)\n",
    "            train_att.append(attn_weight.squeeze().detach().numpy())\n",
    "\n",
    "            # [1, 1, n_step] x [1, n_step, n_hidden] -> [1, 1, n_hidden]\n",
    "            context = torch.matmul(attn_weight, enc_output.transpose(0, 1))\n",
    "            dec_output = dec_output.squeeze(0)\n",
    "            context = context.squeeze(1)\n",
    "            step_logits.append(self.out(torch.cat((dec_output, context), 1)))\n",
    "\n",
    "        # Stack to [n_step, 1, n_class], then drop the batch axis -> [n_step, n_class].\n",
    "        return torch.stack(step_logits).transpose(0, 1).squeeze(0), train_att"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/caijie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n",
      "/Users/caijie/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:14: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0400 cost = 0.000966\n",
      "Epoch: 0800 cost = 0.000305\n",
      "Epoch: 1200 cost = 0.000150\n",
      "Epoch: 1600 cost = 0.000089\n",
      "Epoch: 2000 cost = 0.000057\n",
      "ich mochte ein bier P -> ['i', 'want', 'a', 'beer', 'E']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUQAAAE0CAYAAABUyEoHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADl5JREFUeJzt3XuMpXddx/H3x27p2tKlBQoUelnb\nQi0gpWalpeiyCQLR1EKIIbUkihA2BkkwxASjImtCrATaRCNNWAjtPyoJag0XRULD/kFV6na5ieWS\nlu2FXmiFZdkt3Qv9+sfzHTgOuzNnysx5Zmffr+RkzznPOfN8f3Nm3n2eM9PdVBWSJPiZsQeQpNXC\nIEpSM4iS1AyiJDWDKEnNIEpSO26DmGRjkql+52gpj13Nkrw+yY6x51hOx9Jrk2RLkt0LbD9m1vJ4\n9eeg+nIoyX8n+fWx55pz3AYRuBs4fewhVpsku5NsGXuOJVhLr+NaWstC9jOs8xzgr4CPJDl73JEG\nx20Qq+qxqtoz9hz66ayl13EtrWUxVbWnqu6vqg8AdwJbRh4JOI6DeKTTkyQvSHJLkn1J/jPJJfO2\nX5bk9iR7k1wzgxl3JLkhyUNJ/q6v701yRZLNSb6Q5Lu97bSJ5702ydeSfC/JPyZ5yryP+6YkD/bl\nNX3fp/vzcS7wmT6l2TbxnF9M8u9Jvt9/Pn+l1z+x7/OTfKrX/qUkmye2Hel13Njzz/T1mlKSXJ9k\nf3+tXTCx4YinzIusf1uSG5O8PMmuJNfNaiHL6BBw4thDAFBVx+UF2Dgs/0e3TwXuB94BPAv4S+Cb\nk48FbgVeBLwaeAw4b4Vn3NGXK3r/bwFuBG4A9gFvAs4DPg78cz/nRQynJK9iiNungBt62+uBh4F/\nBS7oNd7d204BTmM4bbuir6/vbU/q++c+5jZg14xep3XAV4A3MJxibQXuA0460us45us1xVq29Fzv\n6xk/CHzuaF+TU65/G7AT+DLwm8AFY65xys/BvonbLwd+AFw49mxVZRAnbv8W8I2J26cCVzH8l2vu\nG+zKie33AS9d4Rl3AL83sf/1/Q1wALh54nHn9PZnAO8HPjCx7XzglX399f3F97S+/ZwjfAPuBrbM\nu+91wGFgT1/29v5OmcHr9JKO2Z6JSwHPPdLrOHnfrF+vKdayBTgIPLFvP73nPGeBtSy2/rmvh1Fj\nv8TPQfU6Hu2vpa1jzzV3WYfmnM0QAwCq6vvAhwGSzN29Y+LxB4Gw8h6dmOnRnuVm4FsT99+d5DBD\nGM8GbpnYdgdwx8THu72qvt3XD045w1kMR1tXH222FXQWcC+wed7990/x3B0T12f1ei3m4araB1BV\nDyY5BJzJcAR+JNOs/2NVdeeyT7pyHgFeyHCqfF91KVeD4/Y9xCO4h+F0EIAk65J8Mcmz5+6rqr2j\nTPaTXsZwqgxAknMYTq3uZljHxoltlyb57MRzF1vDY/xkOO5l+Ma8q6p2A98B/ojhKHql3QucATzU\n+74LeFvPs6BV9HpNekqSkwGSPJXhDOTBBR4/zfr3rcyoK6aqandVfWs1xRAM4qRPABuSvCPJs4A/\nZnjv7K5xxzqiW4DL+ocj5wHXAx+tqgcY3mO8OsmrkpzLsI57lvCxvw68Iskzkrys7/sEcDLwZ0nO\nYjhN+5WazU9EP8dw5H5dh/9tDG9lPDCDfa+EJwDX9GvzLuCLLPw1ttbWv6oZxNZHE7/Wl68Br2B4\nD2ra08pZmvvBx1uAXQxHCL8LUFX/AbwReA/DN9sjwO8v4WP/IcPn4B7g3f0x9/Djz81XgU3Aa5Zh\nHYuqqsPAlQzvhX6F4f3MK6tq/yz2vwLuAZ7I8NpcAly90FHSGlz/qpZVdsQqSaPxCFGSmkGUpGYQ\nJakZRElqBlGSmkE8iiRbx55hJbiuY8
9aXdtqXJdBPLpV92ItE9d17Fmra1t16zKIktSOmV/MfkJO\nqvWcMrP9HeIAJ3LSzPY3K65rGfb1tNl9HQL88Af7OeFnZ7TPU384m/0Ah/c+wroNJ89kX4/ecf/D\nVXXGYo87Zv62m/WcwqU/+l9rpfE8cPXlY4+wYg5v/t7YI6yIr77mz6f6Owk8ZZakZhAlqRlESWoG\nUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlE\nSWoGUZLaqEFMsjHJsfHP/kla88Y+QrwbOH3kGSQJGPmfIa2qx4A9Y84gSXM8ZZakNvYpsyStGqOe\nMi8myVZgK8B6Th55Gklr3ao+Qqyq7VW1qao2nchJY48jaY1b1UGUpFkyiJLUDKIktbF/D3E3kDFn\nkKQ5HiFKUjOIktQMoiQ1gyhJzSBKUjOIktQMoiQ1gyhJzSBKUjOIktQMoiQ1gyhJzSBKUjOIktQM\noiQ1gyhJzSBKUjOIktQMoiQ1gyhJbdR/ZEo6Fp31T3ePPcKKeecf3DT2CCvixVM+ziNESWoGUZKa\nQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoG\nUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqY0SxCQ3Jtk2xr4l6Wg8QpSktmAQk/xXktcm+eUklWRj\nkrcmuSnJ5Uk+n+SRJLcmeW4/Z2M/9rIktyfZm+Sa3vanSQr4HeCd/bgdK75KSZrCYkeIu4BnAxcB\ntwA/Dzyn7/8IcBNwHnAz8N55z/1rhvD9NvD2JOcB7wFOB/4eeHdfv+JoO0+yNcnOJDsPcWBpK5Ok\nJVq3yPZdwIuBJwMf48dB/BfgUuDbwPOA04AL5z33XVV1K0CSB4Czq+pO4ECSg8CjVbVnoZ1X1XZg\nO8CGPLmWsC5JWrLFjhBvYzhCvJAhghfx4yPENwP3AR8CzgROmPfcHRPXDwL56ceVpJWzWBC/DPwc\nsAG4HbgYOJEhim8FXlBVl9BHcZOqau8CH/cxDKSkVWbBIFbVAYbT4oNVdZjhPb8vMASygFOTXA5c\ny9IC93Vgc5Izk/xSkjMe1/SStIym+bWb2xgCRv+5C/gk8PHe9n7gg8Azkzx9yv2+D9gPfJPhvckN\nS5hZklbEYj9UoareOHH9NyY2XTXvoddOXP9/R4tVtXHe7e+zwE+XJWkM/mK2JDWDKEnNIEpSM4iS\n1AyiJDWDKEnNIEpSM4iS1AyiJDWDKEnNIEpSM4iS1AyiJDWDKEnNIEpSM4iS1AyiJDWDKEnNIEpS\nM4iS1Bb9R6YkzXPC2j2OOFQnjD3CqNbuKytJS2QQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGU\npGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGS\nmkGUpDZaEJNcnuTzSR5JcmuS5441iyTBSEFMEuAjwE3AecDNwHuP8LitSXYm2XmIAzOeUtLxZt1I\n+w1wKfBt4HnAacCF8x9UVduB7QAb8uSa5YCSjj+jHCFW1WPAm4H7gA8BZwInjDGLJM0Z65T5pcBb\ngRdU1SX0UaAkjWmsH6psAAo4NcnlwLUMp9GSNJqxgvhJ4OPAbcD7gQ8Cz0zy9JHmkaRxfqhSVYeA\nq+bdfe0Ys0jSHH8xW5KaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoG\nUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWqpqrFnmMqmi9fXrf929thjLLtXPvOFY48grXmf\nrn+4rao2LfY4jxAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAl\nqR
lESWoGUZKaQZSkZhAlqRlESWoGUZKaQZSkZhAlqRlESWqLBjHJliS7ZzCLJI3KI0RJagZRktq0\nQUyS65PsT3JLkgv6zvOTfCrJ3iRfSrJ54gkLbduW5MYkL0+yK8l1y7wuSVqyaYN4DlDA84Dbgb9N\nsg74KPBh4PnA3wAfTnLSQtsmPubzgeuAvwCuX4a1SNJPZd2UjzsEvL2q9iX5E+AB4CXARQxRm/Mk\n4Hzg9AW2/U/f/gXgoqq682g7TbIV2ApwzrOmHVWSHp9pK/NwVe0DqKoHkxwCLgPuBTbPe+z9wKsX\n2DbnYwvFsPe1HdgOsOni9TXlrJL0uEx7yvyUJCcDJHkqcCLwWeAM4KGq2g3cBbwNOIshhkfbNmff\nMswvSctm2iA+AbgmybnAu4AvArcCu4HrkpzDELyrGE6nP7fANklalaYN4j3AExlCeAlwdVUdAq5k\neF/wK8DrgCuran9VHT7atmWeX5KWzaLvIVbVDoafMgO8cd62bwC/epTnLbRt21KGlKRZ8BezJakZ\nRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQ\nJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGU\npGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpGYQJakZRElqBlGSmkGUpDZKEJNsTFJHuGwc\nYx5JAlg38v7PB74zcXvvWINI0thB3FtVe0aeQZIA30OUpB8ZO4h3JNnTl8/M35hka5KdSXY+9L8/\nHGM+SceRsU+ZtwDf7esH5m+squ3AdoBNF6+v2Y0l6Xg0dhDvqaqHR55BkoDxT5kladUY+whxQ5LD\nE7cfqaqDo00j6bg29hHiHQzvIc5d3jDuOJKOZ6McIVbVbiBj7FuSjmbsI0RJWjUMoiQ1gyhJzSBK\nUjOIktQMoiQ1gyhJzSBKUjOIktQMoiQ1gyhJzSBKUjOIktQMoiQ1gyhJzSBKUjOIktQMoiQ1gyhJ\nzSBKUktVjT3DVJI8BNw1w10+FXh4hvubFdd17Fmra5vlus6tqjMWe9AxE8RZS7KzqjaNPcdyc13H\nnrW6ttW4Lk+ZJakZRElqBvHoto89wApxXceetbq2Vbcu30OUpOYRoiQ1gyhJzSBKUjOIktQMoiS1\n/wM5Utetc+EOsQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "input_batch, output_batch, target_batch = get_batch(sentences)\n",
    "# Initial encoder state: [num_layers(=1) * num_directions(=1), batch_size(=1), n_hidden]\n",
    "hidden = torch.zeros(1, 1, n_hidden)\n",
    "\n",
    "model = Seq2seq_att()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Train\n",
    "for epoch in range(2000):\n",
    "    optimizer.zero_grad()\n",
    "    output, _ = model(input_batch, hidden, output_batch)\n",
    "\n",
    "    # output holds one row of logits per target word; drop the batch axis of the ids.\n",
    "    loss = criterion(output, target_batch.squeeze(0))\n",
    "    if (epoch + 1) % 400 == 0:\n",
    "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Test\n",
    "predict, trained_attn = model(input_batch, hidden, output_batch)\n",
    "predict = predict.data.max(1, keepdim=True)[1]\n",
    "# id2word is the id-to-word lookup built with the vocabulary.\n",
    "print(sentences[0], '->', [id2word[n.item()] for n in predict.squeeze()])\n",
    "\n",
    "# Show Attention\n",
    "source_tokens = sentences[0].split()\n",
    "target_tokens = sentences[2].split()\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1)\n",
    "ax.matshow(trained_attn, cmap='viridis')\n",
    "# Pin one tick per matrix cell so each label stays on the row / column it names,\n",
    "# instead of relying on a leading '' to match matplotlib's automatic tick positions.\n",
    "ax.set_xticks(range(len(source_tokens)))\n",
    "ax.set_yticks(range(len(target_tokens)))\n",
    "ax.set_xticklabels(source_tokens, fontdict={'fontsize': 14})\n",
    "ax.set_yticklabels(target_tokens, fontdict={'fontsize': 14})\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
